import functools
import glob
import logging
import os
import time
from copy import deepcopy

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import sisl
# Constants shared by the plot classes:
#   - "spins": labels used to name the traces of each spin component.
#   - "readFuncs": maps a data source name (as listed in the "readingOrder"
#     setting) to a function that, given a plot instance, returns the bound
#     method that reads that source.
PLOTS_CONSTANTS = dict(
    spins=["up", "down"],
    readFuncs=dict(
        fromH=lambda plot: plot._readfromH,
        siesOut=lambda plot: plot._readSiesOut,
    ),
)
Configurable class: giving you the power to tweak parameters. This class is a helper and should be the parent of any class that contains settings.
class Configurable:
    '''
    Helper parent class that gives its subclasses a ``settings`` dictionary.

    Any class in the hierarchy may declare a ``_parameters`` tuple of dicts,
    each with at least a "key" and a "default" entry. ``initSettings`` gathers
    them from the whole class hierarchy and builds ``self.settings``.
    '''

    def initSettings(self, **kwargs):
        '''
        (Re)builds ``self.params`` and ``self.settings`` from the class hierarchy.

        Parameters are gathered from the most derived class first. If a subclass
        redefines a parameter with the same key as an ancestor, the subclass
        definition wins. Parameters that a class merely inherits (without
        declaring its own ``_parameters``) are not counted twice.

        Each setting takes its value from ``kwargs`` when present; otherwise it
        gets a deep copy of the parameter default, so mutable defaults are never
        shared between instances.

        Returns
        -------
        self
        '''
        params = []
        seenKeys = set()
        for clss in type(self).__mro__:
            # Only look at the parameters declared on this exact class:
            # hasattr() would also report the ones inherited from a parent.
            for param in vars(clss).get("_parameters", ()):
                if param["key"] not in seenKeys:
                    seenKeys.add(param["key"])
                    params.append(param)
        self.params = params

        self.settings = {
            param["key"]: kwargs[param["key"]] if param["key"] in kwargs else deepcopy(param["default"])
            for param in self.params
        }
        return self

    def updateSettings(self, **kwargs):
        '''
        Updates the existing settings with the matching entries of ``kwargs``.

        If the settings have not been initialized yet, this delegates to
        ``initSettings``. Keys in ``kwargs`` that are not settings are ignored,
        because methods usually receive other keyword arguments as well.

        Returns
        -------
        self
        '''
        if "settings" not in vars(self):
            return self.initSettings(**kwargs)

        for paramKey, paramValue in kwargs.items():
            if paramKey in self.settings:
                self.settings[paramKey] = paramValue
        return self
These are some decorators that will make your life much easier. They are meant to be used on methods of classes that inherit from Configurable.
# Run the method after having initialized the settings
def afterSettingsInit(method):
    '''
    Decorator: (re)initializes the settings from the call's kwargs, then runs ``method``.

    ``functools.wraps`` keeps the decorated method's name and docstring.
    '''
    @functools.wraps(method)
    def updateAndExecute(obj, *args, **kwargs):
        obj.initSettings(**kwargs)
        return method(obj, *args, **kwargs)
    return updateAndExecute
# Run the method and then initialize the settings
def beforeSettingsInit(method):
    '''
    Decorator: runs ``method`` first, then (re)initializes the settings from the call's kwargs.

    Returns whatever ``method`` returned. ``functools.wraps`` keeps the
    decorated method's name and docstring.
    '''
    @functools.wraps(method)
    def updateAndExecute(obj, *args, **kwargs):
        returns = method(obj, *args, **kwargs)
        obj.initSettings(**kwargs)
        return returns
    return updateAndExecute
# Run the method after having updated the settings
def afterSettingsUpdate(method):
    '''
    Decorator: updates the settings with the call's kwargs, then runs ``method``.

    The same kwargs are also forwarded to ``method``, so it must accept them
    (typically via ``**kwargs``). ``functools.wraps`` keeps the decorated
    method's name and docstring.
    '''
    @functools.wraps(method)
    def updateAndExecute(obj, *args, **kwargs):
        obj.updateSettings(**kwargs)
        return method(obj, *args, **kwargs)
    return updateAndExecute
# Run the method and then update the settings
def beforeSettingsUpdate(method):
    '''
    Decorator: runs ``method`` first, then updates the settings with the call's kwargs.

    Returns whatever ``method`` returned. ``functools.wraps`` keeps the
    decorated method's name and docstring.
    '''
    @functools.wraps(method)
    def updateAndExecute(obj, *args, **kwargs):
        returns = method(obj, *args, **kwargs)
        obj.updateSettings(**kwargs)
        return returns
    return updateAndExecute
Plot, the parent class of all plots. This class contains general methods and settings that can be applied to all plots. Every new plot class should inherit from Plot and, if it defines an __init__() method, call super().__init__() in it to initialize the things that are common to all plots.
class Plot(Configurable):
    '''
    Parent class of all plots.

    Holds the settings and methods shared by every plot: locating the input
    files from an fdf file, optionally setting up a sisl hamiltonian, reading
    data from the sources listed in the "readingOrder" setting and merging
    plots. This class uses ``_plotType``, ``_requirements`` and ``readData``
    without defining them; subclasses are expected to provide them.
    '''

    # Logger used to report recoverable problems (e.g. an unreadable HSX file)
    _logger = logging.getLogger(__name__)

    # Settings shared by all plots: the order in which the data sources are
    # tried and the path to the root fdf file
    _parameters = (
        {
            "key": "readingOrder",
            "name": "Output reading/generating order",
            "default": ("guiOut", "siesOut", "fromH")
        },
        {
            "key": "rootFdf",
            "name": "Path to fdf file",
            "default": None
        },
    )

    @afterSettingsInit
    def __init__(self, **kwargs):
        '''
        Initializes the settings and, if an fdf file is given, reads the data.

        Parameters
        ----------
        readHamiltonian: bool, optional
            if truthy, try to set up the sisl hamiltonian before reading the data.
            A failure is logged as a warning and does not stop the initialization.
        **kwargs:
            any setting of the plot (applied by the afterSettingsInit decorator).
            They are also forwarded to readData.
        '''
        #Give an ID to the plot
        self.id = time.time()

        if self.settings["rootFdf"]:
            #Set the other relevant files
            self.setFiles()

            #Try to read the hamiltonian
            if kwargs.get("readHamiltonian"):
                try:
                    self.setupHamiltonian()
                except Exception:
                    self._logger.warning("Unable to find or read %s.HSX", self.struct, exc_info=True)

            #Process data in the required files, optimally to build a dataframe that can be queried afterwards
            self.readData(**kwargs)

    def __str__(self):
        string = '''
Plot class: {} Plot type: {}
Settings:
{}
'''.format(
            self.__class__.__name__,
            self._plotType,
            "\n".join(["\t- {}: {}".format(key, value) for key, value in self.settings.items()])
        )
        return string

    @afterSettingsUpdate
    def setFiles(self, **kwargs):
        '''
        Reads the root fdf file and builds the list of required files.

        Sets self.rootDir, self.fdfSile, self.struct (the SystemLabel) and
        self.requiredFiles, where "$struct$" in each entry of the class'
        _requirements["files"] is replaced by the SystemLabel.

        Note that the existence of the required files is not checked here.

        Returns
        -------
        self
        '''
        rootFdf = self.settings["rootFdf"]
        self.rootDir = os.path.dirname(rootFdf)
        self.fdfSile = sisl.get_sile(rootFdf)
        self.struct = self.fdfSile.get("SystemLabel")

        # TODO: check that the required files actually exist (e.g. with a
        # RequirementsFilter) and raise a clear error if they don't.
        self.requiredFiles = [
            os.path.join(self.rootDir, req.replace("$struct$", self.struct))
            for req in type(self)._requirements["files"]
        ]

        return self

    def show(self):
        '''Returns the current figure (it is not rebuilt).'''
        return self.figure

    def merge(self, plotsToMerge, inplace = True, **kwargs):
        '''
        Merges this plot's instance with the list of plots provided (EXPERIMENTAL)

        Only the ``data`` lists are merged; call getFigure() afterwards to
        rebuild the figure.

        Parameters
        ----------
        plotsToMerge: Plot or list/tuple of Plot
            the plot(s) whose data is appended to this plot's data.
        inplace: bool, optional
            if True (default), this plot is modified. Otherwise, the merge is
            done on a deep copy of this plot and this plot is left untouched.

        Returns
        -------
        the plot that contains the merged data (self if inplace is True).
        '''
        #Make sure we deal with a list (user can provide a single plot)
        if not isinstance(plotsToMerge, (list, tuple)):
            plotsToMerge = [plotsToMerge]

        target = self if inplace else deepcopy(self)
        for plot in plotsToMerge:
            target.data = [*target.data, *plot.data]

        return target

    @afterSettingsUpdate
    def setupHamiltonian(self, **kwargs):
        '''
        Sets up the hamiltonian for calculations with sisl.

        Reads the output geometry through the fdf sile and then the hamiltonian,
        first through the fdf sile and, if that fails, from the
        <SystemLabel>.HSX file. Stores self.geom, self.H and self.fermi.

        ``**kwargs`` are accepted (and applied as settings by the decorator) so
        that this method can be called like the other decorated methods.

        Returns
        -------
        self
        '''
        self.geom = self.fdfSile.read_geometry(output = True)

        #Try to read the hamiltonian in two different ways
        try:
            #This one is favoured because it may read from TSHS file, which contains all the information of the geometry and basis already
            self.H = self.fdfSile.read_hamiltonian()
        except Exception:
            Hsile = sisl.get_sile(os.path.join(self.rootDir, self.struct + ".HSX"))
            self.H = Hsile.read_hamiltonian(geom = self.geom)

        self.fermi = self.H.fermi_level()

        return self

    def _readFromSources(self):
        '''
        Tries to read the data from the different possible sources in the order
        determined by self.settings["readingOrder"].

        Each source name is looked up in PLOTS_CONSTANTS["readFuncs"] to get the
        reading method. The name of the first source that succeeds is stored in
        self.source and its data is returned.

        Raises
        -------
        Exception
            if every source fails. The message lists the error of each source.
        '''
        errors = []
        #Try to read in the order specified by the user
        for source in self.settings["readingOrder"]:
            try:
                #Get the reading function and execute it
                readingFunc = PLOTS_CONSTANTS["readFuncs"][source](self)
                data = readingFunc()
            except Exception as e:
                #Not fatal: the next source may work. Remember why this one failed.
                errors.append("\t- {}: {}.{}".format(source, type(e).__name__, e))
            else:
                self.source = source
                return data

        raise Exception("Could not read or generate data for {} from any of the possible sources.\n\n Here are the errors for each source:\n\n {} "
                        .format(self.__class__.__name__, "\n".join(errors)))
By following this guide, you will not only build a very flexible plot class that you can use in a wide range of cases, but your class will also be recognized automatically by the GUI. Therefore, you will get graphical interactivity for free.
First of all, every plot class that you develop should inherit from the parent class Plot. In this way, it can benefit from all the generic methods and processes implemented there. The Plot class is designed so that you write as little code as possible while still getting a powerful and dynamic representation.
So, if you were to define a new class to plot, let's say, your mum's weight, you would define it as class MumsWeightPlot(Plot):
If you are relatively new to Python and are having trouble understanding the point of this, you can find information on how class inheritance works in the following links: https://www.w3schools.com/python/python_inheritance.asp (written explanation), https://www.youtube.com/watch?v=Cn7AkDb4pIU (YouTube video)
# BandsPlot, used to plot the band structure
class BandsPlot(Plot):
    '''
    Plot representation of the bands.

    The bands are either read from the $struct$.bands output file ("siesOut")
    or computed with sisl from the hamiltonian along the "path" setting ("fromH").
    '''

    _plotType = "Bands"

    _requirements = {
        "files": ["$struct$.bands", "*.bands"]
    }

    _parameters = (
        {
            "key": "Erange",
            "name": "Energy range",
            "default": [-2, 4],
            "inputField": {
                "type": "range",
                "limits": [-10, 10],
                "displayValues": True,
                "step": 0.1,
                "marks": {**{i: str(i) for i in range(-10, 11)}, 0: "Ef"},
                "updatemode": "drag",
                "units": "eV",
                "width": "offset-s1 s10"
            },
            "tooltip": {
                "message": "Energy range where the bands are displayed. Default: [-2,4]",
                "position": "top"
            }
        },
        {
            "key": "path",
            "name": "Bands path",
            "default": "0,0,0/100/0.5,0,0",
            "inputField": {
                "type": "textinput",
                "placeholder": "Write your path here...",
                "width": "offset-s1 offset-m1 m4 s10",
            },
            "tooltip": {
                "message": '''Path along which bands are drawn in format:
<br>p1x,p1y,p1z/<number of points from P1 to P2>/p2x,p2y,p2z/...
<br>Default: 0,0,0/100/0.5,0,0''',
                "position": "top"
            }
        },
        {
            "key": "ticks",
            "name": "K ticks",
            "default": "A,B",
            "inputField": {
                "type": "textinput",
                "placeholder": "Write your ticks...",
                "width": "offset-s1 offset-m1 m4 s10"
            },
            "tooltip": {
                "message": "Ticks that should be displayed at the corners of the path (separated by commas). Default: A,B",
                "position": "top"
            }
        },
        {
            "key": "lineColors",
            "name": "Lines colors",
            "default": ["black", "blue"],
            "inputField": {
                "type": "color",
                "width": "offset-s1 offset-m1 m4 s10"
            },
            "tooltip": {
                "message": "Choose the colors to display the bands.<br>The second one will only be used if the calculation is spin polarized.",
                "position": "top"
            }
        },
    )

    def __init__(self, **kwargs):
        # Settings, files and data are all handled by Plot.__init__
        super().__init__(**kwargs)

    def _readfromH(self):
        '''
        Computes the bands with sisl along the path given by the "path" setting.

        The path is split by "/": items without commas are taken as numbers of
        divisions, items with exactly three comma-separated values as k-points,
        and anything else is ignored. Requires self.geom and self.H
        (see Plot.setupHamiltonian).

        Returns
        -------
        list containing the result of band.eigh() as its only element.
        '''
        #Get the path requested
        self.path = self.settings["path"]

        bandPoints, divisions = [], []
        for item in self.path.split("/"):
            splittedItem = item.split(",")
            if splittedItem == [item]:
                divisions.append(item)
            elif len(splittedItem) == 3:
                bandPoints.append(splittedItem)
        bandPoints, divisions = np.array(bandPoints, dtype = float), np.array(divisions, dtype = int)

        band = sisl.BandStructure(self.geom, bandPoints, divisions)
        band.set_parent(self.H)

        self.ticks = band.lineartick()
        self.Ks = band.lineark()
        self.kPath = band._k

        bands = band.eigh()

        return [bands]

    def _readSiesOut(self):
        '''
        Reads the bands from the first required file ($struct$.bands).

        Sets self.ticks, self.Ks and self.fermi (0, since the energies in the
        file are already shifted).

        Returns
        -------
        the bands array with axis 1 moved to the front, so that it can be
        iterated as [spinUpBands, spinDownBands].
        '''
        #This should be modified at some point, it's just so that setData works correctly
        self.path = self.settings["path"]
        self.ticks, self.Ks, bands = sisl.get_sile(self.requiredFiles[0]).read_data()
        self.fermi = 0.0 #Energies are already shifted

        #Axes are switched so that the returned array is a list like [spinUpBands, spinDownBands]
        return np.rollaxis(bands, 1)

    @afterSettingsUpdate
    def readData(self, updateFig = True, **kwargs):
        '''
        Gets the information for the bands plot and stores it into self.dfs
        (one dataframe per spin component, with string column names).

        Parameters
        ----------
        updateFig: bool, optional
            if True (default), setData (and therefore getFigure) is run afterwards.

        Returns
        -----------
        self
        '''
        #We try to read from the different sources using the _readFromSources method of the parent Plot class.
        bands = self._readFromSources()

        #Save the bands to dataframes so that we can easily query them
        self.dfs = []
        for spinComponentBands in bands:
            df = pd.DataFrame(spinComponentBands)
            #Set the column headers as strings instead of int (These are the wavefunctions numbers)
            df.columns = df.columns.astype(str)
            self.dfs.append(df)

        if updateFig:
            self.setData(updateFig = updateFig)

        return self

    @afterSettingsUpdate
    def setData(self, updateFig = True, **kwargs):
        '''
        Converts the bands dataframes into data objects for plotly.

        Bands that never enter the window [Erange[0] - 3, Erange[1] + 3]
        (relative to the Fermi level) are dropped. It stores the data under
        self.data, so that it can be accessed by posterior methods.

        Parameters
        ----------
        updateFig: bool, optional
            if True (default), getFigure is run afterwards.

        Returns
        ---------
        self
        '''
        #If the path has changed we need to produce the band structure again.
        #Only the hamiltonian can generate bands along an arbitrary path, so
        #"fromH" is forced for this read and the user's reading order is restored
        #afterwards. updateFig = False avoids re-entering this method.
        if self.path != self.settings["path"]:
            userReadingOrder = self.settings["readingOrder"]
            self.settings["readingOrder"] = ("fromH",)
            try:
                self.readData(updateFig = False)
            finally:
                self.settings["readingOrder"] = userReadingOrder

        self.reqBandsDfs = []
        self.data = []

        Erange = np.array(self.settings["Erange"]) + self.fermi
        lowerLimit, upperLimit = Erange[0] - 3, Erange[1] + 3
        spinPolarized = len(self.dfs) == 2

        for iSpin, df in enumerate(self.dfs):
            #Keep only the energies inside the window and the bands that have any of them
            reqBandsDf = df.where((df < upperLimit) & (df > lowerLimit)).dropna(axis = 1, how = "all")

            #Define the data of the plot as a list of dictionaries {x, y, 'type', 'name'}
            for column in reqBandsDf.columns:
                energies = reqBandsDf[column] - self.fermi
                notNan = energies.notna().to_numpy()
                bandNumber = int(column) + 1
                self.data.append({
                    'x': self.Ks[notNan].tolist(),
                    'y': energies[notNan].tolist(),
                    'mode': 'lines',
                    'name': "{} spin {}".format(bandNumber, PLOTS_CONSTANTS["spins"][iSpin]) if spinPolarized else str(bandNumber),
                    'line': {"color": self.settings["lineColors"][iSpin], 'width': 1},
                    'hoverinfo': 'name',
                    "hovertemplate": '%{y:.2f} eV',
                })

            self.reqBandsDfs.append(reqBandsDf)

        #Sort by band number numerically (a plain string sort would give 1, 10, 11, ..., 2),
        #then by the full name so that the spin components keep their alphabetical order
        self.data = sorted(self.data, key = lambda trace: (int(trace["name"].split()[0]), trace["name"]))

        if updateFig:
            self.getFigure()

        return self

    @afterSettingsUpdate
    def getFigure(self, **kwargs):
        '''
        Define the plot object using the actual data.

        This method can be applied after updating the data so that the plot object is refreshed.
        The figure is stored under self.figure.

        Returns
        ---------
        self.figure: go.Figure()
            the updated version of the figure.
        '''
        self.figure = go.Figure({
            'data': [go.Scatter(**lineData) for lineData in self.data],
            'layout': {
                'title': '{} band structure'.format(self.struct),
                'showlegend': True,
                'hovermode': 'closest',
                'plot_bgcolor': "white",
                'xaxis': {
                    'title': 'K',
                    'showgrid': False,
                    'zeroline': False,
                    'tickcolor': "black",
                    'ticklen': 5,
                    'tickvals': self.ticks[0],
                    #The bands file provides its own tick labels; otherwise use the "ticks" setting
                    'ticktext': self.settings["ticks"].split(",") if self.source != "siesOut" else self.ticks[1]
                },
                'yaxis': {
                    'title': 'E - E<sub>f</sub> (eV)',
                    'showgrid': False,
                    'range': self.settings["Erange"],
                    'tickcolor': "white",
                    'ticklen': 10
                }
            }
        })

        return self.figure
Here comes the fun part!
You can initialize the plot without any settings (your_plot = PlotClass()) and you will get the defaults.
One can always check the current settings either by looking at your_plot.settings or by printing the plot instance: print(your_plot).
# Create a bands plot with every setting at its default value (rootFdf is None, so no files are read yet)
bp = BandsPlot()
# Printing the plot shows its class, its plot type and its current settings
print(bp)
Then, you can provide an fdf file to read the data from. Note that at any point (including initialization) you can pass any setting to the methods, and the settings will be updated before the method is executed.
# Store rootFdf in the settings, then read the SystemLabel from the fdf file and build the list of
# required files next to it. "/path/to/fdf" is a placeholder: replace it with a real fdf file.
bp.setFiles(rootFdf = "/path/to/fdf")
The readData(), setData() and getFigure() methods are the three main methods of a plot class. They are meant to be called in that order during the analysis flow. By default, if you call readData() or setData(), the changes will propagate forward, but you can always prevent that by passing the updateFig = False argument.
# Read the data with new line colors; by default this also runs setData() and getFigure()
bp.readData(lineColors = ["green", "yellow"])
# show() returns the current figure (a notebook cell displays it; a plain script discards it)
bp.show()
# Change the colors again, but with updateFig = False the figure is not rebuilt,
# so the following show() still returns the previous figure
bp.setData(lineColors=["red", "black"], updateFig = False)
bp.show()
# Rebuild the figure from the updated data
bp.getFigure()
Are you struggling to compare two plots? Well, you are in luck! You can take advantage of the merge method of the Plot class to merge as many plots as you wish. At the moment, however, you will need to apply all the settings you want before merging, as afterwards you won't be able to set settings for each plot separately.
# Three independent plots, each with its own fdf file (placeholder paths) and its own line color.
# NOTE(review): only one color is given per plot, so setData would raise IndexError for a
# spin-polarized calculation (it indexes lineColors by spin component).
bp1 = BandsPlot(rootFdf = "/path/to/first/fdf",
lineColors = ["black"])
bp2 = BandsPlot(rootFdf = "/path/to/second/fdf",
lineColors = ["red"])
bp3 = BandsPlot(rootFdf = "/path/to/third/fdf",
lineColors = ["green"])
# merge() appends the traces of bp2 and bp3 to bp1's data (in place) and returns bp1;
# getFigure() then rebuilds bp1's figure using bp1's settings (title, energy range, ticks)
bp1.merge([bp2, bp3]).getFigure()